Skip to content

Commit 1baa8f1

Browse files
committed
ensure consistent topic and topic tag when classifying
1 parent 9b27d62 commit 1baa8f1

File tree

2 files changed

+193
-17
lines changed

2 files changed

+193
-17
lines changed

kitsune/questions/tests/test_utils.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from copy import deepcopy
22

3+
from django.contrib.contenttypes.models import ContentType
34
from parameterized import parameterized
45

6+
from kitsune.flagit.models import FlaggedObject
7+
from kitsune.llm.questions.classifiers import ModerationAction
8+
from kitsune.products.tests import TopicFactory
59
from kitsune.questions.models import Answer, Question
610
from kitsune.questions.tests import AnswerFactory, QuestionFactory
711
from kitsune.questions.utils import (
@@ -10,10 +14,12 @@
1014
num_answers,
1115
num_questions,
1216
num_solutions,
17+
process_classification_result,
1318
remove_pii,
1419
remove_home_dir_pii,
1520
)
1621
from kitsune.sumo.tests import TestCase
22+
from kitsune.users.models import Profile
1723
from kitsune.users.tests import UserFactory
1824

1925

@@ -209,3 +215,162 @@ def test_remove_pii(self):
209215
] = "C:\\Users\\<USERNAME>\\AppData\\Local\\Mozilla\\Firefox"
210216
remove_pii(data)
211217
self.assertDictEqual(data, expected)
218+
219+
220+
class ProcessClassificationResultTests(TestCase):
    """Exercise `process_classification_result` for each moderation outcome."""

    def setUp(self):
        # Two distinct topics make a reclassification observable; the bot
        # profile is the actor the assertions below expect.
        self.topic1 = TopicFactory()
        self.topic2 = TopicFactory()
        self.sumo_bot = Profile.get_sumo_bot()

    def _flags(self, question, **criteria):
        """Return the FlaggedObject rows for `question`, narrowed by `criteria`."""
        return FlaggedObject.objects.filter(
            content_type=ContentType.objects.get_for_model(question),
            object_id=question.id,
            **criteria,
        )

    def _tag_names(self, question):
        """Return the names of the question's tags as a set."""
        return {tag.name for tag in question.my_tags}

    def _assert_topic_flag(self, question, reason_text):
        """Assert an accepted content-moderation flag by the bot quotes `reason_text`."""
        self.assertTrue(
            self._flags(
                question,
                creator=self.sumo_bot,
                status=FlaggedObject.FLAG_ACCEPTED,
                reason=FlaggedObject.REASON_CONTENT_MODERATION,
                notes__contains=reason_text,
            ).exists()
        )

    def test_spam_result(self):
        """A SPAM action marks the question as spam on behalf of the bot."""
        question = QuestionFactory(topic=self.topic1)
        outcome = {"action": ModerationAction.SPAM}

        # Preconditions: a fresh question carries no spam markers.
        self.assertFalse(question.is_spam)
        self.assertIsNone(question.marked_as_spam)
        self.assertIsNone(question.marked_as_spam_by)
        self.assertEqual(question.topic, self.topic1)

        process_classification_result(question, outcome)

        self.assertTrue(question.is_spam)
        self.assertIsNotNone(question.marked_as_spam)
        self.assertEqual(question.marked_as_spam_by, self.sumo_bot)

    def test_flagged_result(self):
        """A FLAG_REVIEW action leaves the question alone and files a pending spam flag."""
        question = QuestionFactory(topic=self.topic1)
        outcome = {
            "action": ModerationAction.FLAG_REVIEW,
            "spam_result": {"reason": "I think it is spam?"},
        }

        self.assertFalse(question.is_spam)
        self.assertFalse(self._flags(question).exists())
        self.assertEqual(question.topic, self.topic1)

        process_classification_result(question, outcome)

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic1)
        self.assertTrue(
            self._flags(
                question,
                creator=self.sumo_bot,
                reason=FlaggedObject.REASON_SPAM,
                status=FlaggedObject.FLAG_PENDING,
                notes__contains="I think it is spam?",
            ).exists()
        )

    def test_topic_result_with_change(self):
        """A NOT_SPAM result naming another topic swaps both the topic and its tag."""
        question = QuestionFactory(topic=self.topic1, tags=[self.topic1.slug])
        outcome = {
            "action": ModerationAction.NOT_SPAM,
            "topic_result": {
                "topic": self.topic2.title,
                "reason": "Dude, it is so topic2.",
            },
        }

        self.assertFalse(question.is_spam)
        self.assertFalse(self._flags(question).exists())
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})

        process_classification_result(question, outcome)

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic2)
        self.assertEqual(self._tag_names(question), {self.topic2.slug})
        self._assert_topic_flag(question, "Dude, it is so topic2.")

    def test_topic_result_with_no_initial_topic(self):
        """A question without a topic gains the classified topic and its tag."""
        question = QuestionFactory(topic=None)
        outcome = {
            "action": ModerationAction.NOT_SPAM,
            "topic_result": {
                "topic": self.topic2.title,
                "reason": "Dude, it is so topic2.",
            },
        }

        self.assertFalse(question.is_spam)
        self.assertFalse(self._flags(question).exists())
        self.assertIsNone(question.topic)
        self.assertFalse(question.my_tags)

        process_classification_result(question, outcome)

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic2)
        self.assertEqual(self._tag_names(question), {self.topic2.slug})
        self._assert_topic_flag(question, "Dude, it is so topic2.")

    def test_topic_result_with_no_change(self):
        """Re-classifying into the current topic keeps topic and tag, yet still flags."""
        question = QuestionFactory(topic=self.topic1, tags=[self.topic1.slug])
        outcome = {
            "action": ModerationAction.NOT_SPAM,
            "topic_result": {
                "topic": self.topic1.title,
                "reason": "Dude, it is so topic1.",
            },
        }

        self.assertFalse(question.is_spam)
        self.assertFalse(self._flags(question).exists())
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})

        process_classification_result(question, outcome)

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})
        self._assert_topic_flag(question, "Dude, it is so topic1.")

kitsune/questions/utils.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
from django.contrib.auth.models import User
88
from django.contrib.contenttypes.models import ContentType
99
from django.contrib.sessions.backends.base import SessionBase
10+
from django.db import transaction
1011

1112
from kitsune.flagit.models import FlaggedObject
1213
from kitsune.llm.questions.classifiers import ModerationAction
1314
from kitsune.products.models import Product, Topic
1415
from kitsune.questions.models import Answer, Question
16+
17+
# from kitsune.tags.models import SumoTag
1518
from kitsune.users.models import Profile
1619
from kitsune.wiki.utils import get_featured_articles as kb_get_featured_articles
1720
from kitsune.wiki.utils import has_visited_kb
@@ -186,20 +189,28 @@ def process_classification_result(
186189
reason=FlaggedObject.REASON_SPAM,
187190
)
188191
case _:
189-
if topic_title := result["topic_result"].get("topic"):
190-
try:
191-
topic = Topic.active.get(title=topic_title, visible=True)
192-
except (Topic.DoesNotExist, Topic.MultipleObjectsReturned):
193-
return
194-
else:
195-
flag_question(
196-
question,
197-
by_user=sumo_bot,
198-
notes=(
199-
"LLM classified as {topic.title}, for the following reason:\n"
200-
f"{result['topic_result']['reason']}"
201-
),
202-
status=FlaggedObject.FLAG_ACCEPTED,
203-
)
204-
question.topic = topic
205-
question.save()
192+
if not (topic_title := result["topic_result"].get("topic")):
193+
return
194+
195+
try:
196+
topic = Topic.active.get(title=topic_title, visible=True)
197+
except (Topic.DoesNotExist, Topic.MultipleObjectsReturned):
198+
return
199+
200+
with transaction.atomic():
201+
flag_question(
202+
question,
203+
by_user=sumo_bot,
204+
notes=(
205+
"LLM classified as {topic.title}, for the following reason:\n"
206+
f"{result['topic_result']['reason']}"
207+
),
208+
status=FlaggedObject.FLAG_ACCEPTED,
209+
)
210+
if question.topic:
211+
question.tags.remove(question.topic.slug)
212+
question.topic = topic
213+
question.save()
214+
question.tags.add(topic.slug)
215+
216+
question.clear_cached_tags()

0 commit comments

Comments
 (0)